{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Match Pyramid\n",
    "\n",
    "参考实现：\n",
    "\n",
    "* [MatchPyramid-for-semantic-matching](https://github.com/ddddwy/MatchPyramid-for-semantic-matching/blob/master/match_pyramid.py)\n",
    "* [MatchZoo](https://github.com/NTMC-Community/MatchZoo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the pinned TensorFlow 2.0 alpha build, quietly (-q), from the Tsinghua PyPI mirror (-i).\n",
    "!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -q tensorflow==2.0.0a\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.0.0-alpha0'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "tf.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters read by build_model().\n",
    "config = dict(\n",
    "    # Input sequence lengths; together they fix the shape of the matching matrix.\n",
    "    query_max_len=20,\n",
    "    doc_max_len=100,\n",
    "    # Conv/pool stack: build_model() indexes the three per-layer lists below by\n",
    "    # layer number and only reads their first `num_conv_layers` entries.\n",
    "    num_conv_layers=2,\n",
    "    filters=[8, 16, 32],\n",
    "    kernel_size=[[5, 5], [3, 3], [3, 3]],\n",
    "    pool_size=[[2, 2], [2, 2], [2, 2]],\n",
    "    # Dropout rate applied to the flattened conv features.\n",
    "    dropout=0.5,\n",
    "    batch_size=32,\n",
    "    # Size of the embedding table shared by query and document tokens.\n",
    "    vocab_size=100,\n",
    "    embedding_size=128,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def build_model(config):\n",
    "    \"\"\"Assemble the MatchPyramid network with the Keras functional API.\n",
    "\n",
    "    Query and document token ids share one embedding table. The dot product of\n",
    "    every (query token, document token) pair forms a matching matrix, which is\n",
    "    passed through a Conv2D/MaxPooling2D stack and a small dense head.\n",
    "\n",
    "    Args:\n",
    "        config: dict with the keys 'query_max_len', 'doc_max_len', 'vocab_size',\n",
    "            'embedding_size', 'num_conv_layers', 'filters', 'kernel_size',\n",
    "            'pool_size' and 'dropout'. The three per-layer lists are indexed by\n",
    "            layer number, so each needs at least 'num_conv_layers' entries.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled tf.keras.Model taking [q_input, d_input] and producing\n",
    "        [matrix, out]: the raw matching matrix and the sigmoid score.\n",
    "    \"\"\"\n",
    "    layers = tf.keras.layers\n",
    "    query_len, doc_len = config['query_max_len'], config['doc_max_len']\n",
    "\n",
    "    q_input = layers.Input(shape=(query_len,), name='q_input')\n",
    "    d_input = layers.Input(shape=(doc_len,), name='d_input')\n",
    "\n",
    "    # A single embedding layer is applied to both inputs, so its weights are shared.\n",
    "    embed = layers.Embedding(config['vocab_size'], config['embedding_size'], name='embedding')\n",
    "    q_vectors = embed(q_input)\n",
    "    d_vectors = embed(d_input)\n",
    "\n",
    "    # Contract the embedding axis: one similarity value per (query, doc) token pair.\n",
    "    interaction = layers.Dot(axes=-1, name='dot')([q_vectors, d_vectors])\n",
    "    print('dot shape: ', interaction.shape)\n",
    "    # Append a channel axis so the interaction map can be consumed by Conv2D.\n",
    "    matrix = layers.Reshape((query_len, doc_len, 1), name='matrix')(interaction)\n",
    "    print('matrix shape: ', matrix.shape)\n",
    "\n",
    "    features = matrix\n",
    "    for layer_idx in range(config['num_conv_layers']):\n",
    "        conv = layers.Conv2D(\n",
    "            filters=config['filters'][layer_idx],\n",
    "            kernel_size=config['kernel_size'][layer_idx],\n",
    "            padding='same',\n",
    "            activation='relu')\n",
    "        features = conv(features)\n",
    "        pool = layers.MaxPooling2D(pool_size=tuple(config['pool_size'][layer_idx]))\n",
    "        features = pool(features)\n",
    "\n",
    "    # Classification head: flatten -> dropout -> 32-unit ReLU -> sigmoid score.\n",
    "    hidden = layers.Flatten()(features)\n",
    "    hidden = layers.Dropout(config['dropout'])(hidden)\n",
    "    hidden = layers.Dense(32, activation='relu')(hidden)\n",
    "    out = layers.Dense(1, activation='sigmoid', name='out')(hidden)\n",
    "\n",
    "    return tf.keras.Model(inputs=[q_input, d_input], outputs=[matrix, out])\n",
    "                                                  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0508 17:10:30.348741 140249640052544 training_utils.py:1152] Output matrix missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to matrix.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dot shape:  (None, 20, 100)\n",
      "matrix shape:  (None, 20, 100, 1)\n",
      "Model: \"model_7\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "q_input (InputLayer)            [(None, 20)]         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "d_input (InputLayer)            [(None, 100)]        0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             12800       q_input[0][0]                    \n",
      "                                                                 d_input[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dot (Dot)                       (None, 20, 100)      0           embedding[0][0]                  \n",
      "                                                                 embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "matrix (Reshape)                (None, 20, 100, 1)   0           dot[0][0]                        \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_29 (Conv2D)              (None, 20, 100, 8)   208         matrix[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_21 (MaxPooling2D) (None, 10, 50, 8)    0           conv2d_29[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_30 (Conv2D)              (None, 10, 50, 16)   1168        max_pooling2d_21[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_22 (MaxPooling2D) (None, 5, 25, 16)    0           conv2d_30[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "flatten_7 (Flatten)             (None, 2000)         0           max_pooling2d_22[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "dropout_7 (Dropout)             (None, 2000)         0           flatten_7[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_13 (Dense)                (None, 32)           64032       dropout_7[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "out (Dense)                     (None, 1)            33          dense_13[0][0]                   \n",
      "==================================================================================================\n",
      "Total params: 78,241\n",
      "Trainable params: 78,241\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Build the network from the hyper-parameter dict and print its layer table.\n",
    "model = build_model(config)\n",
    "model.summary()\n",
    "\n",
    "# Only the 'out' head is given a loss; the 'matrix' output is returned without one.\n",
    "model.compile(\n",
    "    optimizer='sgd',\n",
    "    loss={'out': 'binary_crossentropy'},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "vocabulary_file must be specified and must not be empty.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-43-794319138798>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mdoc_vocab_file\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mquery_str2id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlookup_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_table_from_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery_vocab_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0mquery_id2str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlookup_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_to_string_table_from_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery_vocab_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'unk'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mdoc_str2id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlookup_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_table_from_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdoc_vocab_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/machine-learning-notes/lib/python3.6/site-packages/tensorflow/python/ops/lookup_ops.py\u001b[0m in \u001b[0;36mindex_table_from_file\u001b[0;34m(vocabulary_file, num_oov_buckets, vocab_size, default_value, hasher_spec, key_dtype, name, key_column_index, value_column_index, delimiter)\u001b[0m\n\u001b[1;32m   1233\u001b[0m   if vocabulary_file is None or (\n\u001b[1;32m   1234\u001b[0m       isinstance(vocabulary_file, six.string_types) and not vocabulary_file):\n\u001b[0;32m-> 1235\u001b[0;31m     \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"vocabulary_file must be specified and must not be empty.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1236\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0mnum_oov_buckets\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1237\u001b[0m     raise ValueError(\"num_oov_buckets must be greater or equal than 0, got %d.\"\n",
      "\u001b[0;31mValueError\u001b[0m: vocabulary_file must be specified and must not be empty."
     ]
    }
   ],
   "source": [
    "from tensorflow.python.ops import lookup_ops\n",
    "\n",
    "# TODO: point these at the query / document vocabulary files before running this\n",
    "# cell; index_table_from_file raises ValueError while a path is empty.\n",
    "query_vocab_file = ''\n",
    "doc_vocab_file = ''\n",
    "\n",
    "# token -> id tables (tokens missing from the vocabulary map to id 0) and the\n",
    "# reverse id -> token tables (ids missing from the vocabulary map to 'unk').\n",
    "query_str2id = lookup_ops.index_table_from_file(query_vocab_file, default_value=0)\n",
    "query_id2str = lookup_ops.index_to_string_table_from_file(query_vocab_file, default_value='unk')\n",
    "doc_str2id = lookup_ops.index_table_from_file(doc_vocab_file, default_value=0)\n",
    "doc_id2str = lookup_ops.index_to_string_table_from_file(doc_vocab_file, default_value='unk')\n",
    "\n",
    "# Id used as the padding value when batching. It must be int64: the str2id tables\n",
    "# emit int64 ids and padded_batch requires every padding value to have the same\n",
    "# dtype as the component it pads (an int32 constant is rejected).\n",
    "unk_id = tf.constant(0, dtype=tf.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input-pipeline settings passed to the dataset builder functions.\n",
    "dataset_config = dict(\n",
    "    # Buffer size handed to Dataset.shuffle().\n",
    "    shuffle_size=10000000,\n",
    "    # Parallelism of every Dataset.map() stage.\n",
    "    num_parallel_calls=4,\n",
    "    # Sequences are truncated and padded to these lengths.\n",
    "    query_max_len=20,\n",
    "    doc_max_len=100,\n",
    "    # Batch sizes for train/eval and for prediction respectively.\n",
    "    batch_size=32,\n",
    "    predict_batch_size=32\n",
    ")\n",
    "\n",
    "def _common_process_dataset(dataset, config):\n",
    "    \"\"\"Parse raw text lines into unbatched (query_ids, doc_ids, label) examples.\n",
    "\n",
    "    Each line is split on '@' into query, document and label fields. Query and\n",
    "    document are split on single spaces, truncated to their maximum lengths and\n",
    "    mapped to ids through the module-level `query_str2id` / `doc_str2id` tables.\n",
    "    The input is shuffled before parsing.\n",
    "\n",
    "    Args:\n",
    "        dataset: tf.data.Dataset of scalar string lines.\n",
    "        config: dict with 'shuffle_size', 'num_parallel_calls', 'query_max_len'\n",
    "            and 'doc_max_len'.\n",
    "\n",
    "    Returns:\n",
    "        tf.data.Dataset of (query_ids, doc_ids, label) tuples; label is parsed\n",
    "        as an int32 scalar.\n",
    "    \"\"\"\n",
    "    def _stage(ds, fn):\n",
    "        # Every stage is a parallel map followed by an auto-tuned prefetch.\n",
    "        # The constant is spelled AUTOTUNE; `AUTO_TUNE` is not a TensorFlow attribute.\n",
    "        return ds.map(fn, num_parallel_calls=config['num_parallel_calls']).prefetch(\n",
    "            tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "    def _split_fields(line):\n",
    "        # `tf.string_split` is a TF 1.x-only name; the compat.v1 endpoint keeps the\n",
    "        # same `delimiter` semantics under TF 2.x. Split once instead of per field.\n",
    "        fields = tf.compat.v1.string_split([line], delimiter='@').values\n",
    "        return fields[0], fields[1], fields[2]\n",
    "\n",
    "    def _tokenize(q, d, l):\n",
    "        # `tf.string_to_number` is exposed as `tf.strings.to_number` in TF 2.x.\n",
    "        return (tf.compat.v1.string_split([q], delimiter=' ').values,\n",
    "                tf.compat.v1.string_split([d], delimiter=' ').values,\n",
    "                tf.strings.to_number(l, out_type=tf.int32))\n",
    "\n",
    "    dataset = dataset.shuffle(config['shuffle_size'])\n",
    "    dataset = _stage(dataset, _split_fields)\n",
    "    dataset = _stage(dataset, _tokenize)\n",
    "    # Truncate over-long sequences so they fit the fixed padded shapes used later.\n",
    "    dataset = _stage(\n",
    "        dataset,\n",
    "        lambda q, d, l: (q[:config['query_max_len']], d[:config['doc_max_len']], l))\n",
    "    # Replace tokens with vocabulary ids.\n",
    "    dataset = _stage(\n",
    "        dataset,\n",
    "        lambda q, d, l: (query_str2id.lookup(q), doc_str2id.lookup(d), l))\n",
    "\n",
    "    return dataset\n",
    "\n",
    "def _build_dataset(dataset, config):\n",
    "    \"\"\"Turn raw text lines into padded ((query_ids, doc_ids), label) batches.\n",
    "\n",
    "    Args:\n",
    "        dataset: tf.data.Dataset of scalar string lines.\n",
    "        config: dict with 'batch_size', 'query_max_len', 'doc_max_len' and\n",
    "            'num_parallel_calls', plus the keys _common_process_dataset reads.\n",
    "\n",
    "    Returns:\n",
    "        tf.data.Dataset of ((query_ids, doc_ids), label) batches; the id tensors\n",
    "        are padded to 'query_max_len' / 'doc_max_len' with the unknown-token id.\n",
    "    \"\"\"\n",
    "    dataset = _common_process_dataset(dataset, config)\n",
    "    # The lookup tables emit int64 ids and padded_batch requires padding values of\n",
    "    # the same dtype, so cast the module-level unk_id explicitly.\n",
    "    pad_id = tf.cast(unk_id, tf.int64)\n",
    "    # The keyword is `padded_shapes` (not `padding_shapes`); plain lists are used\n",
    "    # for the shapes because tf.Dimension is not part of the TF 2.x namespace.\n",
    "    dataset = dataset.padded_batch(\n",
    "        batch_size=config['batch_size'],\n",
    "        padded_shapes=([config['query_max_len']],\n",
    "                       [config['doc_max_len']],\n",
    "                       []),\n",
    "        padding_values=(pad_id, pad_id, tf.constant(0, dtype=tf.int32))\n",
    "    )\n",
    "    # Regroup into (inputs, target) as expected by fit()/evaluate().\n",
    "    dataset = dataset.map(\n",
    "        lambda q, d, l: ((q, d), l),\n",
    "        num_parallel_calls=config['num_parallel_calls']\n",
    "    ).prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "    return dataset\n",
    "\n",
    "def build_train_dataset(train_files, config):\n",
    "    \"\"\"Build the batched training dataset from a list of text files.\n",
    "\n",
    "    Args:\n",
    "        train_files: list of file paths; every file is read line by line.\n",
    "        config: pipeline settings forwarded to _build_dataset. The optional\n",
    "            'skip_count' key gives the number of leading lines to skip in each\n",
    "            file (default 0).\n",
    "\n",
    "    Returns:\n",
    "        tf.data.Dataset of ((query_ids, doc_ids), label) batches.\n",
    "    \"\"\"\n",
    "    skip_count = config.get('skip_count', 0)\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(train_files)\n",
    "    dataset = dataset.flat_map(lambda x: tf.data.TextLineDataset(x).skip(skip_count))\n",
    "    return _build_dataset(dataset, config)\n",
    "\n",
    "def build_eval_dataset(eval_files, config):\n",
    "    \"\"\"Build the batched evaluation dataset from a list of text files.\n",
    "\n",
    "    Args:\n",
    "        eval_files: list of file paths; every file is read line by line.\n",
    "        config: pipeline settings forwarded to _build_dataset. The optional\n",
    "            'skip_count' key gives the number of leading lines to skip in each\n",
    "            file (default 0), matching the train and predict builders.\n",
    "\n",
    "    Returns:\n",
    "        tf.data.Dataset of ((query_ids, doc_ids), label) batches.\n",
    "    \"\"\"\n",
    "    skip_count = config.get('skip_count', 0)\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(eval_files)\n",
    "    dataset = dataset.flat_map(lambda x: tf.data.TextLineDataset(x).skip(skip_count))\n",
    "    # _build_dataset needs the config as its second argument.\n",
    "    return _build_dataset(dataset, config)\n",
    "\n",
    "def build_predict_dataset(predict_files, config):\n",
    "    \"\"\"Build the batched, label-free dataset fed to model.predict().\n",
    "\n",
    "    The predict files are assumed to carry a label field as well; it is parsed\n",
    "    together with the other fields and then dropped.\n",
    "\n",
    "    Args:\n",
    "        predict_files: list of file paths; every file is read line by line.\n",
    "        config: dict with 'predict_batch_size', 'query_max_len', 'doc_max_len'\n",
    "            and 'num_parallel_calls', plus the keys _common_process_dataset\n",
    "            reads. The optional 'skip_count' key gives the number of leading\n",
    "            lines to skip in each file (default 0).\n",
    "\n",
    "    Returns:\n",
    "        tf.data.Dataset whose elements are 1-tuples ((query_ids, doc_ids),).\n",
    "    \"\"\"\n",
    "    skip_count = config.get('skip_count', 0)\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(predict_files)\n",
    "    dataset = dataset.flat_map(lambda x: tf.data.TextLineDataset(x).skip(skip_count))\n",
    "    # NOTE(review): _common_process_dataset shuffles its input, so predictions are\n",
    "    # not returned in file order -- confirm this is acceptable for scoring.\n",
    "    dataset = _common_process_dataset(dataset, config)\n",
    "    # Drop the label; only the two id sequences are model inputs.\n",
    "    dataset = dataset.map(\n",
    "        lambda q, d, l: (q, d),\n",
    "        num_parallel_calls=config['num_parallel_calls']\n",
    "    ).prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "    # The lookup tables emit int64 ids and padded_batch requires padding values of\n",
    "    # the same dtype, so cast the module-level unk_id explicitly.\n",
    "    pad_id = tf.cast(unk_id, tf.int64)\n",
    "    # Keywords are `padded_shapes` / `padding_values`; plain lists are used for the\n",
    "    # shapes because tf.Dimension is not part of the TF 2.x namespace.\n",
    "    dataset = dataset.padded_batch(\n",
    "        batch_size=config['predict_batch_size'],\n",
    "        padded_shapes=([config['query_max_len']],\n",
    "                       [config['doc_max_len']]),\n",
    "        padding_values=(pad_id, pad_id)\n",
    "    )\n",
    "    # Wrap the inputs in a 1-tuple: a bare (q, d) element would be interpreted by\n",
    "    # Keras as (inputs, targets) instead of as the model's two inputs.\n",
    "    dataset = dataset.map(lambda q, d: ((q, d),))\n",
    "    return dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train and eval model\n",
    "import os\n",
    "\n",
    "# TODO: fill in the training / evaluation files (one 'query@doc@label' example per\n",
    "# line, tokens separated by spaces). The builders take the file list as their\n",
    "# first argument and the pipeline settings as their second.\n",
    "train_files = []\n",
    "eval_files = []\n",
    "if not train_files or not eval_files:\n",
    "    raise ValueError('train_files and eval_files must each list at least one data file.')\n",
    "\n",
    "train_dataset = build_train_dataset(train_files, dataset_config)\n",
    "eval_dataset = build_eval_dataset(eval_files, dataset_config)\n",
    "\n",
    "model_dir = '/opt/algo_nfs/kdd_luozhouyang/models/match_pyramid/20190508'\n",
    "# os.makedirs (there is no os.mkdirs) creates missing parent directories;\n",
    "# exist_ok=True keeps the cell safe to re-run.\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "\n",
    "ckpt_path = os.path.join(model_dir, 'mp-{epoch:04d}.ckpt')\n",
    "log_dir = os.path.join(model_dir, 'log')\n",
    "\n",
    "callbacks = [\n",
    "    tf.keras.callbacks.ModelCheckpoint(ckpt_path, save_weights_only=True),\n",
    "    tf.keras.callbacks.TensorBoard(log_dir),\n",
    "    tf.keras.callbacks.EarlyStopping(patience=10)\n",
    "]\n",
    "\n",
    "# NOTE(review): no `epochs` argument is passed, so fit() runs its default number of\n",
    "# epochs; EarlyStopping(patience=10) can only trigger on longer runs -- confirm the\n",
    "# intended epoch count.\n",
    "model.fit(train_dataset, validation_data=eval_dataset, callbacks=callbacks)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test model\n",
    "\n",
    "# TODO: fill in the files to score; they use the same 'query@doc@label' line\n",
    "# format as the training files (the label is parsed and then dropped).\n",
    "predict_files = []\n",
    "if not predict_files:\n",
    "    raise ValueError('predict_files must list at least one data file.')\n",
    "\n",
    "predict_dataset = build_predict_dataset(predict_files, dataset_config)\n",
    "\n",
    "# For a multi-output model predict() returns a list of NumPy arrays in the order of\n",
    "# the model outputs ([matrix, out]); it is neither a dict nor a list of tensors,\n",
    "# so the arrays are unpacked positionally and printed directly.\n",
    "matrix, out = model.predict(predict_dataset)\n",
    "\n",
    "print(matrix)\n",
    "print(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:machine-learning-notes] *",
   "language": "python",
   "name": "conda-env-machine-learning-notes-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
